"""
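Evaluate voxel captions against NSD fMRI responses via sentence similarity.

For each selected voxel, its generated caption is embedded with
sentence-transformers/all-MiniLM-L6-v2, its cosine similarity to the COCO
captions of every NSD stimulus is averaged over annotators, and that
similarity profile is correlated (Spearman or Pearson) with the voxel's
responses on the train split and on the test split (stimuli listed in
``sharedix`` of the NSD experiment design). FDR-corrected p-values are saved
alongside the aggregated per-voxel scores.

Example invocations:

[Image]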

## For LaVCa
python -m evaluation.eval_sentence_similarity \
    --subject_names subj01 subj02 subj05 subj07 \
    --atlasname streams floc-faces floc-places floc-bodies floc-words \
    --modality image \
    --modality_hparam default \
    --model_name CLIP-ViT-B-32 \
    --reduce_dims default 0 \
    --dataset_name OpenImages \
    --max_samples full \
    --dataset_path ./data/OpenImages/frames_518x518px \
    --dataset_captioner MiniCPM-Llama3-V-2_5 \
    --voxel_selection pvalues_corrected 0.05 \
    --layer_selection best \
    --caption_model MeaCap \
    --keywords_model gpt-4o-2024-08-06 \
    --correct_model default \
    --candidate_num 50 \
    --key_num 5 \
    --temperature 0.05 \
    --filter_th 0.15 \
    --cc_method spearman \
    --device cpu 

## Concat
python -m evaluation.eval_sentence_similarity \
    --subject_names subj01 subj02 subj05 subj07 \
    --atlasname streams floc-faces floc-places floc-bodies floc-words \
    --modality image \
    --modality_hparam default \
    --model_name CLIP-ViT-B-32 \
    --reduce_dims default 0 \
    --dataset_name OpenImages \
    --max_samples full \
    --dataset_path ./data/OpenImages/frames_518x518px \
    --dataset_captioner MiniCPM-Llama3-V-2_5 \
    --voxel_selection pvalues_corrected 0.05 \
    --layer_selection best \
    --caption_model MeaCap \
    --keywords_model gpt-4o-2024-08-06 \
    --correct_model default \
    --candidate_num 50 \
    --key_num -1 \
    --temperature -1 \
    --filter_th -1 \
    --cc_method spearman \
    --device cpu 

## For BrainSCUBA
python -m evaluation.eval_sentence_similarity \
    --subject_names subj01 subj02 subj05 subj07 \
    --atlasname streams floc-faces floc-places floc-bodies floc-words \
    --betas_norm \
    --modality image \
    --modality_hparam default \
    --model_name CLIP-ViT-B-32 \
    --reduce_dims default 0 \
    --dataset_name OpenImages \
    --max_samples full \
    --dataset_path ./data/OpenImages/frames_518x518px \
    --dataset_captioner MiniCPM-Llama3-V-2_5 \
    --voxel_selection pvalues_corrected 0.05 \
    --layer_selection best \
    --caption_model brainscuba  \
    --keywords_model default \
    --correct_model default \
    --candidate_num -1 \
    --temperature -1 \
    --filter_th -1 \
    --key_num -1 \
    --tau 150 \
    --cc_method spearman \
    --device cuda

"""

import argparse
import json
import os

import numpy as np
import scipy
import scipy.io
import scipy.stats
import torch
from himalaya.scoring import correlation_score
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from utils.nsd_access import NSDAccess
from utils.utils import (
    search_best_layer, make_filename, TrnVal, gen_nulldistrib_gauss,
    fdr_correction, collect_fmri_byroi_for_nsd, create_volume_index_and_weight_map
)

# Seed torch's global RNG at import time (torch.manual_seed is an alias of this).
torch.random.manual_seed(42)

def load_resp_wholevoxels_for_nsd(subject_name, dataset="all", atlas="streams"):
    """Load a subject's NSD fMRI responses for the TRAIN and VALID splits.

    Args:
        subject_name: NSD subject identifier (e.g. ``"subj01"``).
        dataset: Not used by this function; kept so existing callers that pass
            it keep working.
        atlas: Atlas name forwarded to ``collect_fmri_byroi_for_nsd``.

    Returns:
        ``TrnVal`` whose ``trn`` / ``val`` fields hold whatever
        ``collect_fmri_byroi_for_nsd`` returns for the TRAIN / VALID split.
    """
    # Dict insertion order keeps the TRAIN load ahead of the VALID load.
    by_split = {
        split: collect_fmri_byroi_for_nsd(subject_name, trainvalid=split, atlasname=atlas)
        for split in ("TRAIN", "VALID")
    }
    return TrnVal(trn=by_split["TRAIN"], val=by_split["VALID"])

    
# Caption models this script can read. MeaCap-style models store their caption
# in a keywords JSON file and encode keyword/temperature/filter settings in all
# filenames; "brainscuba" stores its caption in a tau-named text file.
_MEACAP_STYLE_MODELS = ("MeaCap", "default")
_SUPPORTED_CAPTION_MODELS = _MEACAP_STYLE_MODELS + ("brainscuba",)
_WTE_MODEL_PATH = "sentence-transformers/all-MiniLM-L6-v2"
# Filename prefix of cached caption embeddings; tied to _WTE_MODEL_PATH.
_WTE_EMBS_PREFIX = "all-MiniLM-L6-v2_embs"
_COCO_ANNOT_DIR = "./data/nsd/coco_annots"


def _load_npy(path):
    """Load a ``.npy`` file, retrying with ``allow_pickle=True`` for object arrays.

    The files read here are produced by this script itself, so enabling pickle
    on the retry is acceptable; do not point this at untrusted data.
    """
    try:
        return np.load(path)
    except ValueError:
        # np.load refuses object arrays unless allow_pickle=True.
        return np.load(path, allow_pickle=True)


def _result_tag(args):
    """Return the filename tag identifying the caption configuration.

    The same tag is embedded in the caption/keywords file, the cached caption
    embedding, the per-voxel correlation files and the aggregated outputs, so
    every per-configuration path in this script is derived from it.
    """
    if args.caption_model in _MEACAP_STYLE_MODELS:
        return (
            f"{args.caption_model}_kmodel_{args.keywords_model}_{args.key_num}keys_"
            f"{args.temperature}temp_{args.filter_th}th_{args.candidate_num}cands_"
            f"cmodel_{args.correct_model}"
        )
    tag = f"{args.caption_model}_tau{args.tau}"
    return f"{tag}_betanorm" if args.betas_norm else tag


def _read_caption(caption_model, resp_save_path, tag):
    """Read the generated caption text for one voxel from ``resp_save_path``."""
    if caption_model in _MEACAP_STYLE_MODELS:
        with open(os.path.join(resp_save_path, f"keys_and_text_{tag}.json"), "r") as f:
            return json.load(f)["text"]
    with open(os.path.join(resp_save_path, f"caption_{tag}.txt"), "r") as f:
        return f.read()


def _correlate(cc_method, resp, sims):
    """Correlate a voxel's responses with a per-stimulus similarity profile.

    Raises:
        ValueError: if ``cc_method`` is neither ``"pearson"`` nor ``"spearman"``.
    """
    if cc_method == "pearson":
        return correlation_score(resp, sims)
    if cc_method == "spearman":
        return scipy.stats.spearmanr(resp, sims).statistic
    raise ValueError(f"Unsupported cc_method: {cc_method!r}")


def _mean_cosine_similarity(sim_func, caption_embs, annot_tensors, device):
    """Cosine similarity of one caption to every stimulus, averaged over annotators."""
    caption_tensor = torch.tensor(caption_embs).to(device)
    return np.mean(
        [sim_func(caption_tensor, annots).cpu().numpy() for annots in annot_tensors],
        axis=0,
    )


def _split_coco_annotations(nsda, subject_name, shared_idx):
    """Split the subject's COCO annotations into ``(train, test)`` arrays.

    Stimuli whose id appears in ``shared_idx`` form the test split; all other
    stimuli form the train split.
    """
    stims = np.load(f"./data/nsd/fmri/{subject_name}/{subject_name}_stims_ave.npy")
    is_test = np.isin(stims, shared_idx)
    annots = np.array(nsda.read_image_coco_info(stims))
    return annots[~is_test], annots[is_test]


def _load_or_encode_coco_embeddings(wte_model, subject_name, annots_tr, annots_te):
    """Return per-annotator COCO caption embeddings for the train and test splits.

    Embeddings are cached as ``.npy`` files under ``_COCO_ANNOT_DIR``; on a cache
    miss the raw captions are also written there as text, one per line. The
    annotator count is taken from the first test image, and every image is
    indexed up to that count.
    """
    os.makedirs(_COCO_ANNOT_DIR, exist_ok=True)
    annotator_num = len(annots_te[0])
    prefix = f"{_COCO_ANNOT_DIR}/{subject_name}_coco"
    embs_tr_all, embs_te_all = [], []
    for i in range(annotator_num):
        embs_tr_path = f"{prefix}_embs_tr_{i}.npy"
        embs_te_path = f"{prefix}_embs_te_{i}.npy"
        if os.path.exists(embs_tr_path) and os.path.exists(embs_te_path):
            embs_tr_all.append(np.load(embs_tr_path))
            embs_te_all.append(np.load(embs_te_path))
            print(f"Already processed: {subject_name}_coco_embs")
            continue

        captions_tr = [annot[i]["caption"].replace("\n", "") for annot in annots_tr]
        captions_te = [annot[i]["caption"].replace("\n", "") for annot in annots_te]
        print(f"Annots_tr: {len(captions_tr)}, Annots_te: {len(captions_te)}")
        for split, captions in (("tr", captions_tr), ("te", captions_te)):
            with open(f"{prefix}_annots_{split}_{i}.txt", "w", encoding="utf-8") as f:
                f.writelines(f"{caption}\n" for caption in captions)

        embs_tr = wte_model.encode(captions_tr)
        embs_te = wte_model.encode(captions_te)
        print(embs_tr.shape, embs_te.shape)
        np.save(embs_tr_path, embs_tr)
        np.save(embs_te_path, embs_te)
        embs_tr_all.append(embs_tr)
        embs_te_all.append(embs_te)
    return embs_tr_all, embs_te_all


def _load_reducer_projector(args, subject_name, layer, filename):
    """Load the dimensionality-reduction projector, or return ``None`` if disabled."""
    import pickle

    if args.reduce_dims[0] == "default":
        return None
    stem = (
        f"./data/stim_features/nsd/{args.modality}/{args.modality_hparam}/"
        f"{args.model_name}/{layer}/projector_{subject_name}_ave_{filename}"
    )
    try:
        return np.load(f"{stem}.npy", allow_pickle=True).item()
    except (OSError, ValueError, AttributeError, pickle.UnpicklingError):
        # No usable .npy projector: fall back to the pickled variant.
        return np.load(f"{stem}.pkl", allow_pickle=True)


def _load_cached_scores(cc_tr_path, cc_te_path):
    """Return cached ``(cc_tr, cc_te)`` for a voxel, or ``None`` if unavailable.

    An unreadable cache is reported and treated as missing so the voxel is
    recomputed instead of aborting the whole run.
    """
    if not (os.path.exists(cc_tr_path) and os.path.exists(cc_te_path)):
        return None
    try:
        return _load_npy(cc_tr_path), _load_npy(cc_te_path)
    except Exception as err:  # best-effort cache read; recompute on any failure
        print(f"Could not load cached scores ({err}); recomputing.")
        return None


def main(args):
    """Score every selected voxel's caption against NSD responses and save results.

    For each subject: select voxels, then for each voxel embed its caption,
    compute the annotator-averaged cosine similarity to the COCO captions of
    every stimulus, and correlate that profile with the voxel's train/test
    responses. Per-voxel scores are cached next to the caption files; the
    aggregated scores and FDR-corrected p-values are saved per subject.

    Raises:
        ValueError: if ``args.caption_model`` is not a supported caption model.
    """
    if args.caption_model not in _SUPPORTED_CAPTION_MODELS:
        raise ValueError(
            f"Unsupported caption_model {args.caption_model!r}; "
            f"expected one of {_SUPPORTED_CAPTION_MODELS}"
        )

    score_root_path = "./data/nsd/encoding"
    modality = args.modality
    modality_hparam = args.modality_hparam
    model_name = args.model_name
    file_type = args.voxel_selection[0]
    threshold = float(args.voxel_selection[1])
    filename = make_filename(args.reduce_dims[0:2])
    tag = _result_tag(args)
    nsda = NSDAccess('./data/NSD')

    sim_func = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    wte_model = SentenceTransformer(_WTE_MODEL_PATH).to(args.device)

    # The NSD experiment design does not depend on the subject: load it once.
    # Its indices are 1-based, hence the -1.
    nsd_expdesign = scipy.io.loadmat('./data/NSD/nsddata/experiments/nsd/nsd_expdesign.mat')
    sharedix = nsd_expdesign['sharedix'] - 1

    for subject_name in args.subject_names:
        print(subject_name)
        print(f"Modality: {modality}, Modality hparams: {modality_hparam}, Feature: {model_name}, Filename: {filename}")
        # Select the feature layer for this subject.
        model_score_dir = f"{score_root_path}/{subject_name}/scores/{modality}/{modality_hparam}/{model_name}"
        if args.layer_selection == "best":
            target_best_cv_layer, _, _ = search_best_layer(model_score_dir, filename, select_topN="all")
        else:
            target_best_cv_layer = args.layer_selection
        print(f"Best layer: {target_best_cv_layer}")

        annots_tr, annots_te = _split_coco_annotations(nsda, subject_name, sharedix)
        annots_embs_tr_all, annots_embs_te_all = _load_or_encode_coco_embeddings(
            wte_model, subject_name, annots_tr, annots_te
        )
        # The annotation embeddings are the same for every voxel: move them to
        # the device once instead of re-uploading them per voxel and annotator.
        annot_tensors_tr = [torch.tensor(embs).to(args.device) for embs in annots_embs_tr_all]
        annot_tensors_te = [torch.tensor(embs).to(args.device) for embs in annots_embs_te_all]

        volume_index, weight_index_map, _ = create_volume_index_and_weight_map(
            subject_name=subject_name,
            file_type=file_type,
            threshold=threshold,
            model_score_dir=model_score_dir,
            target_best_cv_layer=target_best_cv_layer,
            filename=filename,
            nsda=nsda,
            atlasnames=args.atlasname,  # expected to be a list of atlas names
        )

        reducer_projector = _load_reducer_projector(args, subject_name, target_best_cv_layer, filename)
        print(reducer_projector)

        resp = load_resp_wholevoxels_for_nsd(subject_name, "all", atlas="cortex")

        subject_out_dir = (
            f"./data/nsd/insilico/{subject_name}/{args.dataset_name}_{args.max_samples}/"
            f"{modality}/{modality_hparam}/{model_name}_{filename}/whole"
        )
        cc_scores = {"tr": [], "te": []}
        for voxel_index in tqdm(volume_index):
            print(f"voxel_index: {voxel_index}")
            resp_save_path = f"{subject_out_dir}/voxel{str(voxel_index).zfill(5)}"
            print(f"Now processing: {voxel_index}")
            weight_index = weight_index_map[voxel_index]
            resp_trn = resp.trn[:, weight_index]
            resp_val = resp.val[:, weight_index]

            cc_tr_path = f"{resp_save_path}/{args.cc_method}_cc_sentence_sim_{tag}_tr.npy"
            cc_te_path = f"{resp_save_path}/{args.cc_method}_cc_sentence_sim_{tag}_te.npy"
            cached = _load_cached_scores(cc_tr_path, cc_te_path)
            if cached is not None:
                print(f"Already processed: {voxel_index}")
                cc_tr, cc_te = cached
            else:
                caption = _read_caption(args.caption_model, resp_save_path, tag)
                print(caption)
                embs_save_path = os.path.join(resp_save_path, f"{_WTE_EMBS_PREFIX}_{tag}.npy")
                if os.path.exists(embs_save_path):
                    caption_embs = _load_npy(embs_save_path)
                else:
                    caption_embs = wte_model.encode([caption]).squeeze()
                    np.save(embs_save_path, caption_embs)

                sims_tr = _mean_cosine_similarity(sim_func, caption_embs, annot_tensors_tr, args.device)
                sims_te = _mean_cosine_similarity(sim_func, caption_embs, annot_tensors_te, args.device)
                print(sims_tr.shape)

                cc_tr = _correlate(args.cc_method, resp_trn, sims_tr)
                cc_te = _correlate(args.cc_method, resp_val, sims_te)
                print(cc_tr, cc_te)
                np.save(cc_tr_path, cc_tr)
                np.save(cc_te_path, cc_te)

            cc_scores["tr"].append(cc_tr)
            cc_scores["te"].append(cc_te)

        result_stem = f"{subject_out_dir}/{'_'.join(args.atlasname)}_{args.cc_method}"
        cc_scores = {split: np.array(scores) for split, scores in cc_scores.items()}
        np.save(f"{result_stem}_cc_sentence_sim_{tag}.npy", cc_scores)

        # perf type = block: FDR-correct each split's scores against a null
        # distribution produced by gen_nulldistrib_gauss.
        n_samples = {"tr": resp.trn.shape[0], "te": resp.val.shape[0]}
        pvalues_corrected = {}
        for split, n_sample in n_samples.items():
            rccs = gen_nulldistrib_gauss(len(volume_index), n_sample)
            significant_voxels, pvalue_corrected = fdr_correction(cc_scores[split], rccs)
            print(f"Number of significant voxels: {len(significant_voxels)}")
            print(f"pvalue_corrected: {pvalue_corrected}")
            pvalues_corrected[split] = pvalue_corrected
        np.save(f"{result_stem}_cc_pvalues_corrected_sentence_sim_{tag}.npy", pvalues_corrected)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=(
            "Correlate the sentence similarity between voxel captions and NSD "
            "COCO captions with each voxel's fMRI responses."
        )
    )
    parser.add_argument(
        "--subject_names",
        nargs="*",
        type=str,
        required=True,
        help="NSD subjects to evaluate, e.g. subj01 subj02.",
    )
    parser.add_argument(
        "--atlasname",
        nargs="*",
        type=str,
        required=True,
        help="Atlas names used for voxel selection; also joined into output filenames.",
    )
    parser.add_argument(
        "--betas_norm",
        action="store_true",
        help="Use the '_betanorm' variant of tau-based (e.g. brainscuba) caption files.",
    )
    parser.add_argument(
        "--modality",
        type=str,
        required=True,
        help="Name of the modality to use.",
    )
    parser.add_argument(
        "--modality_hparam",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--reduce_dims",
        nargs="+",  # the first value is always read, so at least one is required
        type=str,
        required=True,
        help="Dimensionality-reduction spec; a first value of 'default' disables the projector.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--max_samples",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="Not read by this script; accepted for CLI compatibility.",
    )
    parser.add_argument(
        "--dataset_captioner",
        type=str,
        required=True,
        help="Not read by this script; accepted for CLI compatibility.",
    )
    parser.add_argument(
        "--voxel_selection",
        nargs=2,  # main() reads exactly [file_type, threshold]
        metavar=("FILE_TYPE", "THRESHOLD"),
        type=str,
        required=True,
        help="Voxel-score file type and threshold used to select voxels, e.g. 'pvalues_corrected 0.05'.",
    )
    parser.add_argument(
        "--layer_selection",
        type=str,
        required=False,
        default="best",
        help="'best' to pick the best cross-validated layer, or an explicit layer name.",
    )
    parser.add_argument(
        "--caption_model",
        type=str,
        required=True,
        help="Name of the captioning model to use: 'MeaCap', 'default' or 'brainscuba'.",
    )
    parser.add_argument(
        "--keywords_model",
        type=str,
        required=False,
    )
    parser.add_argument(
        "--correct_model",
        type=str,
        required=True,
        choices=["None", "default", "gpt-4o-2024-08-06", 'gpt-4o-mini-2024-07-18'],
        help="Name of the correction model to use.",
    )
    parser.add_argument(
        "--candidate_num",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--key_num",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--temperature",
        type=float,
        required=True,
    )
    parser.add_argument(
        "--filter_th",
        type=float,
        required=False,
    )
    parser.add_argument(
        "--tau",
        type=float,
        required=False,
        help="tau value identifying tau-based (e.g. brainscuba) caption files.",
    )
    parser.add_argument(
        "--cc_method",
        type=str,
        required=True,
        choices=["spearman", "pearson"],
        help="Correlation used between voxel responses and caption similarities.",
    )
    parser.add_argument(
        "--device",
        type=str,
        required=True,
        help="Device to use.",
    )
    parser.add_argument(
        "--embs_only",
        action="store_true",
        required=False,
        default=False,
        help="Currently not read by main().",
    )
    args = parser.parse_args()
    main(args)